#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/types.h>
#include <pthread.h>
#include <assert.h>
#include <iostream>
#include <dirent.h>
#include <vector>
#include <cstring>
#include <fstream>
#include <algorithm>

#define num 100       // size of the global `data` array
#define thread_num 3  // number of worker threads main() asks pool_init() for

#define MAX_FILE 50   // input files per sort task (and per temp result file)
// Memory budget for the final merge: 50 MiB expressed as a count of int64_t values.
// Parenthesized so the macro expands correctly inside larger expressions
// (e.g. `x % MAX_MEMORY`), which the unparenthesized form did not.
#define MAX_MEMORY (50 * 1024 * 1024 / sizeof(int64_t))


int data[num];        // data to be sorted (NOTE(review): not referenced elsewhere in this source)
int thread_unit = 10; // each task handles 10 items (NOTE(review): not referenced elsewhere in this source)

std::vector<std::string> filenames;   // input file paths, filled by setFilenames()
std::vector<std::string> resultFiles; // temp result file paths, one per group of MAX_FILE inputs

/*
 * Collect the input files: every entry of ./test_data/ that is not a directory,
 * does not start with '.', and whose name ends in ".txt" is appended to the
 * global `filenames` as "./test_data/<name>". The order follows readdir() and is
 * therefore unspecified. Exits the process if the directory cannot be opened.
 */
void setFilenames()
{
    const std::string path = "./test_data/";
    const std::string ext = ".txt";

    DIR *dir = opendir(path.c_str());
    if (dir == NULL)
    {
        std::cerr << "path error------" << std::endl;
        exit(-1);
    }

    for (dirent *dp = readdir(dir); dp != NULL; dp = readdir(dir))
    {
        if (dp->d_type == DT_DIR || dp->d_name[0] == '.')
            continue;

        const std::string filename = dp->d_name;
        // Require ".txt" as a suffix, not just somewhere in the name ("a.txt.bak" is skipped).
        if (filename.size() < ext.size() ||
            filename.compare(filename.size() - ext.size(), ext.size(), ext) != 0)
            continue;

        filenames.push_back(path + filename);
    }

    // Release the directory stream; the original leaked it.
    closedir(dir);
}

void setResultFiles()
{
    for (int i = 1; i <= filenames.size(); i++)
    {
        if (i % MAX_FILE == 0)
        {
            std::string resultFile = "tmp/tmp" + std::to_string(i / MAX_FILE) + ".txt";
            resultFiles.push_back(resultFile);
        }
        else if (i == filenames.size())
        {
            std::string resultFile = "tmp/tmp" + std::to_string(i / MAX_FILE + 1) + ".txt";
            resultFiles.push_back(resultFile);
        }
    }
}

/*
 * Sort task: read the raw int64_t values (binary, native byte order) from
 * filenames[from] .. filenames[from + length - 1], sort them in memory, and write
 * them to resultFiles[from / MAX_FILE]. Trailing bytes that do not form a whole
 * int64_t are ignored. Exits the process on any open/write failure.
 *
 * Called from worker threads; it only reads the global `filenames` /
 * `resultFiles` vectors (assumed not to be modified while tasks run).
 */
void myMergeSort(int from, int length)
{
    std::vector<int64_t> sourceData;

    // bufferSize is in BYTES; the array holds exactly that many bytes of int64_t.
    constexpr std::size_t bufferSize = 4096;
    int64_t buffer[bufferSize / sizeof(int64_t)];

    // Inclusive range [from, from + length - 1], i.e. `< from + length`
    // (the original `< from + length - 1` skipped the last file of each group).
    for (int index = from; index < from + length; index++)
    {
        const std::string &filename = filenames.at(index);
        std::ifstream file(filename, std::ios::binary);
        if (!file)
        {
            std::cerr << "can't open file" << filename << std::endl;
            exit(-1);
        }

        while (true)
        {
            file.read(reinterpret_cast<char *>(buffer), sizeof(buffer));
            const std::size_t numbersRead =
                static_cast<std::size_t>(file.gcount()) / sizeof(int64_t);
            if (numbersRead == 0)
                break;
            sourceData.insert(sourceData.end(), buffer, buffer + numbersRead);
        }
    }

    std::sort(sourceData.begin(), sourceData.end());

    // Tasks start on multiples of MAX_FILE, so this is the group number (0-based).
    const std::size_t resultIndex = from / MAX_FILE;
    std::ofstream file(resultFiles.at(resultIndex), std::ios::binary);
    if (!file)
    {
        std::cerr << "can't open file to write" << std::endl;
        exit(-1);
    }
    file.write(reinterpret_cast<const char *>(sourceData.data()),
               sizeof(int64_t) * sourceData.size());
    if (!file)
    {
        std::cerr << "failed to write " << resultFiles.at(resultIndex) << std::endl;
        exit(-1);
    }
}



/*
 * K-way merge of the sorted temp files named in `resultFiles` into result.txt.
 *
 * The MAX_MEMORY budget (counted in int64_t values) is split evenly between one
 * input buffer per temp file and one output buffer. When an input buffer runs
 * out it is refilled from its file, so temp files larger than one buffer are
 * merged completely. Values are appended to result.txt, so the caller is expected
 * to truncate it first. Exits the process on any open/write failure.
 */
void finalMerge()
{
    const std::size_t numFile = resultFiles.size();
    // numFile input buffers + 1 output buffer share the budget; never allow a zero-size buffer.
    const std::size_t maxBufferSize = std::max<std::size_t>(1, MAX_MEMORY / (numFile + 1));

    std::vector<std::ifstream> inputs(numFile);
    std::vector<std::vector<int64_t>> buffers(numFile, std::vector<int64_t>(maxBufferSize));
    std::vector<std::size_t> length(numFile, 0); // valid values currently in buffers[i]
    std::vector<std::size_t> index(numFile, 0);  // next unconsumed position in buffers[i]

    // Load the next chunk of file i; length[i] becomes 0 once that file is exhausted.
    auto refill = [&](std::size_t i) {
        inputs[i].read(reinterpret_cast<char *>(buffers[i].data()),
                       static_cast<std::streamsize>(maxBufferSize * sizeof(int64_t)));
        length[i] = static_cast<std::size_t>(inputs[i].gcount()) / sizeof(int64_t);
        index[i] = 0;
    };

    for (std::size_t i = 0; i < numFile; i++)
    {
        inputs[i].open(resultFiles.at(i), std::ios::binary);
        if (!inputs[i])
        {
            std::cerr << "can't open file" << resultFiles.at(i) << std::endl;
            exit(-1);
        }
        refill(i);
    }

    std::ofstream resultFile("result.txt", std::ios::binary | std::ios::app);
    if (!resultFile)
    {
        std::cerr << "file to write can not open " << std::endl;
        exit(-1);
    }

    std::vector<int64_t> writeBuffer;
    writeBuffer.reserve(maxBufferSize);

    // Append the buffered output to result.txt and empty the buffer.
    auto flush = [&]() {
        resultFile.write(reinterpret_cast<const char *>(writeBuffer.data()),
                         static_cast<std::streamsize>(writeBuffer.size() * sizeof(int64_t)));
        if (!resultFile)
        {
            std::cerr << "failed to write result.txt" << std::endl;
            exit(-1);
        }
        writeBuffer.clear();
    };

    while (true)
    {
        // Pick the source whose head value is smallest (lowest index wins ties).
        std::size_t best = numFile; // numFile means "no source has data left"
        for (std::size_t i = 0; i < numFile; i++)
        {
            if (index[i] >= length[i])
                continue;
            if (best == numFile || buffers[i][index[i]] < buffers[best][index[best]])
                best = i;
        }
        if (best == numFile)
            break;

        writeBuffer.push_back(buffers[best][index[best]++]);

        // Buffer drained: pull the next chunk so large temp files are not truncated.
        if (index[best] >= length[best])
            refill(best);

        if (writeBuffer.size() == maxBufferSize)
            flush();
    }

    flush();
}

// One queued job: a run of `length` input files starting at index `from`,
// handed to myMergeSort(from, length) by a worker thread.
struct worker
{
    int from;     // index of the first input file in this job
    int length;   // number of input files in this job
    worker *next; // next queued job; NULL at the tail of the queue
};
using task = worker;

/* Thread pool state */
struct thread_pool
{
    pthread_mutex_t queue_lock; // mutex serializing access to the task queue
    pthread_cond_t queue_ready; // signalled when a task is queued or shutdown is requested
    task *queue_head;           // singly linked list of pending tasks; add_task appends at the tail
    int shutdown;               // nonzero once pool_destroy() has requested shutdown
    pthread_t *threadid;        // array of max_thread_num worker thread ids
    int max_thread_num;         // number of worker threads in the pool
    int cur_queue_size;         // number of tasks currently waiting in the queue
};

// NOTE(review): no definition of this add_task overload appears in this source;
// main() uses the add_task(int &, int &) overload instead.
int add_task(void *(*process)(void *arg), void *arg);
void *thread_routine(void *arg);

// The single pool instance: NULL until pool_init() runs; pool_destroy() frees it and resets it to NULL.
static thread_pool *pool = nullptr;
/*
 * Create the global pool with `max_thread_num` worker threads, each running
 * thread_routine(). The task queue starts empty.
 * Exits the process if max_thread_num < 1, an allocation fails, or a thread
 * cannot be created. A failed pthread_create would otherwise leave an
 * uninitialized id that pool_destroy() later joins.
 */
void pool_init(int max_thread_num)
{
    if (max_thread_num < 1)
    {
        std::cerr << "thread pool needs at least one thread" << std::endl;
        exit(-1);
    }

    pool = (thread_pool *)malloc(sizeof(thread_pool));
    if (pool == NULL)
    {
        std::cerr << "failed to allocate thread pool" << std::endl;
        exit(-1);
    }
    pthread_mutex_init(&(pool->queue_lock), NULL);
    pthread_cond_init(&(pool->queue_ready), NULL);
    pool->queue_head = NULL;
    pool->max_thread_num = max_thread_num;
    pool->cur_queue_size = 0;
    pool->shutdown = 0;

    pool->threadid = (pthread_t *)malloc(max_thread_num * sizeof(pthread_t));
    if (pool->threadid == NULL)
    {
        std::cerr << "failed to allocate thread ids" << std::endl;
        exit(-1);
    }

    for (int i = 0; i < max_thread_num; i++)
    {
        if (pthread_create(&(pool->threadid[i]), NULL, thread_routine, NULL) != 0)
        {
            std::cerr << "failed to create worker thread" << std::endl;
            exit(-1);
        }
    }
}

int add_task(int &i, int &j)
{                                                   /*向线程池中加入任务*/
    task *newworker = (task *)malloc(sizeof(task)); /*构造一个新任务*/
    newworker->next = NULL;
    newworker->from = i - j;
    newworker->length = j;
    pthread_mutex_lock(&(pool->queue_lock)); // 向任务队列中添加任务是互斥操作要上锁
    task *member = pool->queue_head;
    if (member != NULL)
    {
        while (member->next != NULL)
            member = member->next;
        member->next = newworker;
    }
    else
    {
        pool->queue_head = newworker; // 尾插法
    }
    assert(pool->queue_head != NULL);
    pool->cur_queue_size++;
    pthread_mutex_unlock(&(pool->queue_lock));
    pthread_cond_signal(&(pool->queue_ready)); /*条件已满足，向等待这个条件的线程发出信号，唤醒休眠的进程*/
    return 0;
}

/*
 * Shut the pool down: request shutdown, wake and join every worker, free any
 * tasks still queued, destroy the mutex/condition variable, free the pool and
 * reset `pool` to NULL.
 * Returns 0 on success. Returns -1 if there is no pool (never created or already
 * destroyed) or shutdown was already requested.
 */
int pool_destroy()
{
    // After a successful destroy `pool` is NULL, so this also guards a second call
    // (the original dereferenced the NULL pointer here).
    if (pool == NULL)
        return -1;

    // Set the flag under the same mutex the workers read it with, so there is no data race.
    pthread_mutex_lock(&(pool->queue_lock));
    if (pool->shutdown)
    {
        pthread_mutex_unlock(&(pool->queue_lock));
        return -1;
    }
    pool->shutdown = 1;
    // Wake every waiting worker so each one observes the shutdown request.
    pthread_cond_broadcast(&(pool->queue_ready));
    pthread_mutex_unlock(&(pool->queue_lock));

    for (int i = 0; i < pool->max_thread_num; i++)
        pthread_join(pool->threadid[i], NULL);
    free(pool->threadid);

    // Free any tasks the workers left in the queue.
    while (pool->queue_head != NULL)
    {
        task *head = pool->queue_head;
        pool->queue_head = head->next;
        free(head);
    }

    pthread_mutex_destroy(&(pool->queue_lock));
    pthread_cond_destroy(&(pool->queue_ready));
    free(pool);
    pool = NULL;
    return 0;
}

/*
 * Worker loop run by every pool thread. `arg` is unused; pool_init passes NULL.
 *
 * Blocks on queue_ready while the queue is empty. Otherwise it pops the head
 * task and runs myMergeSort(from, length) on it. After shutdown is requested,
 * the worker keeps draining queued tasks and exits only once the queue is empty,
 * so work queued before pool_destroy() is not silently dropped. The original
 * exited immediately and discarded pending tasks.
 */
void *thread_routine(void *arg)
{
    (void)arg;
    printf("starting thread 0x%lx\n", pthread_self());
    while (1)
    {
        pthread_mutex_lock(&(pool->queue_lock));
        // pthread_cond_wait releases the mutex while blocked and re-acquires it before
        // returning; the loop re-checks the predicate to cope with spurious wakeups.
        while (pool->cur_queue_size == 0 && !pool->shutdown)
        {
            printf("thread 0x%lx is waiting\n", pthread_self());
            pthread_cond_wait(&(pool->queue_ready), &(pool->queue_lock));
        }

        // Leave only when shutdown was requested AND no work remains; unlock before exiting.
        if (pool->shutdown && pool->cur_queue_size == 0)
        {
            pthread_mutex_unlock(&(pool->queue_lock));
            printf("thread 0x%lx will exit\n", pthread_self());
            pthread_exit(NULL);
        }

        assert(pool->cur_queue_size != 0);
        assert(pool->queue_head != NULL);
        // Dequeue the head task while still holding the lock.
        pool->cur_queue_size--;
        task *worker = pool->queue_head;
        pool->queue_head = worker->next;
        pthread_mutex_unlock(&(pool->queue_lock));

        // Do the (slow) work outside the lock.
        printf("thread 0x%lx is starting to work on data from %d, length %d\n", pthread_self(), worker->from, worker->length);
        myMergeSort(worker->from, worker->length);
        free(worker);
    }
    return NULL;
}

/*
 * Program entry point:
 *  1. collect the input files and name one temp file per group of MAX_FILE inputs;
 *  2. queue one sort task per group on a pool of thread_num workers;
 *  3. shut the pool down;
 *  4. truncate result.txt and k-way merge the temp files into it.
 */
int main(int argc, char **argv)
{
    setFilenames();
    setResultFiles();
    pool_init(thread_num);

    // One task per group: filenames[start, start + length), where only the last group
    // may be shorter than MAX_FILE. add_task(end, length) stores from = end - length.
    // (The original counter started at 0, so a partial first group became from=1, length=n-1.)
    const int total = static_cast<int>(filenames.size());
    for (int start = 0; start < total; start += MAX_FILE)
    {
        int length = std::min(MAX_FILE, total - start);
        int end = start + length;
        add_task(end, length);
    }

    // Give the workers time to take the queued tasks before shutdown is requested.
    // NOTE(review): a fixed sleep is a timing assumption, not a completion guarantee.
    sleep(20);
    pool_destroy();

    // Truncate result.txt; finalMerge() appends to it.
    // (Named so it no longer shadows the global `resultFiles`.)
    std::ofstream clearResult("result.txt", std::ios::binary);
    clearResult.close();

    finalMerge();

    return 0;
}